#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
基础模型类，负责量化模型加载与推理。
移植自 legacy src/detection_core/base_model.py。
"""

import logging
from time import time

import cv2
import numpy as np
from hobot_dnn import pyeasy_dnn as dnn

# Module logger with a fixed name "RDK_YOLO" (not __name__); presumably shared with
# other modules of the project so they log to one channel — confirm against callers.
logger = logging.getLogger("RDK_YOLO")


class BaseModel:
    """
    Base wrapper around a D-Robotics quantized model loaded with hobot_dnn ``pyeasy_dnn``.

    Covers model loading, logging of input/output tensor info, BGR -> NV12
    preprocessing, and inference. Only the first entry of the object returned
    by ``dnn.load`` (index ``[0]``) is used.

    Attributes:
        quantize_model: Return value of ``dnn.load``; indexed with ``[0]`` throughout.
        model_input_height (int): Model input height, read from the first input tensor.
        model_input_weight (int): Model input *width*. The attribute name is a
            historical typo and is kept for backward compatibility.
        y_scale (float): original image height / model input height, set by ``resizer``.
        x_scale (float): original image width / model input width, set by ``resizer``.
    """

    def __init__(self, model_file: str) -> None:
        """
        Load the quantized model file and log its tensor information.

        Args:
            model_file (str): Path to the quantized model file.

        Raises:
            Exception: Any exception raised by ``dnn.load`` is logged and re-raised unchanged.
        """
        # Keep the try body to the single call that can fail to load.
        try:
            self.quantize_model = dnn.load(model_file)
        except Exception as e:
            logger.error("❌ Failed to load model file: %s", model_file)
            logger.error("You can download the model file from the following docs: ./models/download.md")
            logger.error(e)
            raise

        model = self.quantize_model[0]
        self._log_tensors("input", model.inputs)
        self._log_tensors("output", model.outputs)

        # Assumes NCHW layout, i.e. shape == (N, C, H, W), so [2:4] is (H, W).
        # TODO(review): confirm this holds for every model this class loads (NHWC would break it).
        self.model_input_height, self.model_input_weight = model.inputs[0].properties.shape[2:4]

    @staticmethod
    def _log_tensors(kind: str, tensors) -> None:
        """
        Log index, name, dtype and shape of each tensor under a colored header.

        Args:
            kind (str): Label used in the header and per-line prefix ("input" / "output").
            tensors: Iterable of hobot_dnn tensors exposing ``.name`` and ``.properties``.
        """
        logger.info("\033[1;32m" + "-> %s tensors" + "\033[0m", kind)
        for i, tensor in enumerate(tensors):
            logger.info(
                "%s[%d], name=%s, type=%s, shape=%s",
                kind,
                i,
                tensor.name,
                tensor.properties.dtype,
                tensor.properties.shape,
            )

    def resizer(self, img: np.ndarray) -> np.ndarray:
        """
        Resize an image to the model input size and record the scale factors.

        Sets ``self.y_scale`` / ``self.x_scale`` to the ratio of the original
        image size to the model input size along each axis.

        Args:
            img (np.ndarray): Image whose first two dimensions are (height, width).

        Returns:
            np.ndarray: Image resized to ``model_input_height`` rows by
            ``model_input_weight`` columns (nearest-neighbor interpolation).
        """
        img_h, img_w = img.shape[0:2]
        self.y_scale = img_h / self.model_input_height
        self.x_scale = img_w / self.model_input_weight
        # cv2.resize expects dsize as (width, height), not (height, width).
        return cv2.resize(
            img,
            (self.model_input_weight, self.model_input_height),
            interpolation=cv2.INTER_NEAREST,
        )

    def bgr2nv12(self, bgr_img: np.ndarray) -> np.ndarray:
        """
        Resize a BGR image to the model input size and convert it to NV12.

        NV12 is YUV 4:2:0 semi-planar: the full-resolution Y plane followed by
        a single plane of interleaved U/V samples (UVUV...).

        Args:
            bgr_img (np.ndarray): BGR image of any size (it is resized first via ``resizer``).

        Returns:
            np.ndarray: Flat 1-D buffer of length ``height * width * 3 // 2``.
        """
        resized = self.resizer(bgr_img)
        height, width = resized.shape[0], resized.shape[1]
        area = height * width
        # I420 = Y plane, then the full U plane, then the full V plane (flattened here).
        yuv420p = cv2.cvtColor(resized, cv2.COLOR_BGR2YUV_I420).reshape((area * 3 // 2,))
        # Interleave the separate U and V planes into NV12's packed UV plane.
        # The `area // 4` chroma size relies on even height and width.
        uv_interleaved = yuv420p[area:].reshape((2, area // 4)).transpose((1, 0)).reshape((area // 2,))

        nv12 = np.empty_like(yuv420p)  # every element is overwritten below
        nv12[:area] = yuv420p[:area]
        nv12[area:] = uv_interleaved
        return nv12

    def forward(self, input_tensor: np.ndarray):
        """
        Run inference on the first loaded model.

        Args:
            input_tensor (np.ndarray): Preprocessed input (e.g. the NV12 buffer from ``bgr2nv12``).

        Returns:
            Whatever ``self.quantize_model[0].forward`` returns (the raw output tensors).
        """
        return self.quantize_model[0].forward(input_tensor)

    def c2numpy(self, outputs):
        """
        Extract the ``.buffer`` of each inference output tensor.

        Args:
            outputs: Iterable of output tensors, as returned by ``forward``.

        Returns:
            list: ``[tensor.buffer for tensor in outputs]``, in the same order.
            Presumably each buffer is a numpy array view of the C tensor data (confirm against hobot_dnn docs).
        """
        return [dnn_tensor.buffer for dnn_tensor in outputs]
